{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Working with data\n",
    "\n",
    "# Both FashionMNIST splits live under ./data; download=True fetches the files\n",
    "# when they are not already there. ToTensor() turns each image into a tensor.\n",
    "\n",
    "# Training split.\n",
    "training_data = datasets.FashionMNIST(root=\"data\", train=True, download=True, transform=ToTensor())\n",
    "\n",
    "# Held-out test split.\n",
    "test_data = datasets.FashionMNIST(root=\"data\", train=False, download=True, transform=ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<torch.utils.data.dataloader.DataLoader object at 0x12b989970>\n",
      "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])\n",
      "Shape of y: torch.Size([64]) torch.int64\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Wrap each split in a DataLoader that yields mini-batches of `batch_size` samples.\n",
    "train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)\n",
    "test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)\n",
    "\n",
    "print(train_dataloader)\n",
    "\n",
    "# Peek at the first training batch only, to show the tensor shapes and the label dtype.\n",
    "X, y = next(iter(train_dataloader))\n",
    "print(f\"Shape of X [N, C, H, W]: {X.shape}\")\n",
    "print(f\"Shape of y: {y.shape} {y.dtype}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using mps device\n",
      "NeuralNetwork(\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=512, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Creating Model\n",
    "\n",
    "# Choose the compute device: CUDA first, then Apple MPS, otherwise the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "elif torch.backends.mps.is_available():\n",
    "    device = \"mps\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(f\"Using {device} device\")\n",
    "\n",
    "# Define model\n",
    "class NeuralNetwork(nn.Module):\n",
    "    \"\"\"Fully connected classifier: 28*28 input features -> 512 -> 512 -> 10 logits.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        # Attribute names and layer order are kept exactly as they were because they\n",
    "        # determine the state_dict keys that are saved to and loaded from model.pth.\n",
    "        hidden = 512\n",
    "        layers = [\n",
    "            nn.Linear(28 * 28, hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, hidden),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(hidden, 10),\n",
    "        ]\n",
    "        self.linear_relu_stack = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Flatten `x` and return the raw (unnormalised) class logits.\"\"\"\n",
    "        return self.linear_relu_stack(self.flatten(x))\n",
    "\n",
    "# Instantiate the network and move its parameters to the selected device.\n",
    "model = NeuralNetwork().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "# Optimizing the Model Parameters\n",
    "\n",
    "# Cross-entropy loss on the model's logits, minimised with plain SGD (learning rate 1e-3).\n",
    "# NOTE: the optimizer captures the parameters of the `model` object that exists when this\n",
    "# cell runs, so re-run this cell whenever `model` is re-created.\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def train(dataloader, model, loss_fn, optimizer):\n",
    "    \"\"\"Run one pass over `dataloader`, updating `model` after every batch.\n",
    "\n",
    "    Each batch is moved to the notebook-level `device` before the forward pass.\n",
    "    The batch loss and a sample counter are printed every 100 batches.\n",
    "    \"\"\"\n",
    "    size = len(dataloader.dataset)\n",
    "    model.train()\n",
    "    for step, (inputs, targets) in enumerate(dataloader):\n",
    "        inputs = inputs.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        # Forward pass and loss for this batch.\n",
    "        batch_loss = loss_fn(model(inputs), targets)\n",
    "\n",
    "        # Backward pass, parameter update, then clear gradients for the next batch.\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Report progress only on every 100th batch.\n",
    "        if step % 100 != 0:\n",
    "            continue\n",
    "        loss = batch_loss.item()\n",
    "        current = (step + 1) * len(inputs)\n",
    "        print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def test(dataloader, model, loss_fn):\n",
    "    \"\"\"Evaluate `model` on `dataloader` and print accuracy and the mean per-batch loss.\n",
    "\n",
    "    Each batch is moved to the notebook-level `device`; gradients are disabled\n",
    "    for the whole evaluation.\n",
    "    \"\"\"\n",
    "    size = len(dataloader.dataset)\n",
    "    num_batches = len(dataloader)\n",
    "    model.eval()\n",
    "    loss_sum = 0\n",
    "    hits = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in dataloader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            logits = model(inputs)\n",
    "            loss_sum += loss_fn(logits, targets).item()\n",
    "            # A prediction is the index of the largest logit along dim 1.\n",
    "            matches = logits.argmax(1) == targets\n",
    "            hits += matches.type(torch.float).sum().item()\n",
    "    # Loss is averaged over batches, accuracy over individual samples.\n",
    "    test_loss = loss_sum / num_batches\n",
    "    correct = hits / size\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 2.298593  [   64/60000]\n",
      "loss: 2.298250  [ 6464/60000]\n",
      "loss: 2.303465  [12864/60000]\n",
      "loss: 2.308721  [19264/60000]\n",
      "loss: 2.295448  [25664/60000]\n",
      "loss: 2.311263  [32064/60000]\n",
      "loss: 2.299534  [38464/60000]\n",
      "loss: 2.303914  [44864/60000]\n",
      "loss: 2.295982  [51264/60000]\n",
      "loss: 2.301181  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 12.5%, Avg loss: 2.304109 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 2.298593  [   64/60000]\n",
      "loss: 2.298250  [ 6464/60000]\n",
      "loss: 2.303465  [12864/60000]\n",
      "loss: 2.308721  [19264/60000]\n",
      "loss: 2.295448  [25664/60000]\n",
      "loss: 2.311263  [32064/60000]\n",
      "loss: 2.299534  [38464/60000]\n",
      "loss: 2.303914  [44864/60000]\n",
      "loss: 2.295982  [51264/60000]\n",
      "loss: 2.301181  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 12.5%, Avg loss: 2.304109 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 2.298593  [   64/60000]\n",
      "loss: 2.298250  [ 6464/60000]\n",
      "loss: 2.303465  [12864/60000]\n",
      "loss: 2.308721  [19264/60000]\n",
      "loss: 2.295448  [25664/60000]\n",
      "loss: 2.311263  [32064/60000]\n",
      "loss: 2.299534  [38464/60000]\n",
      "loss: 2.303914  [44864/60000]\n",
      "loss: 2.295982  [51264/60000]\n",
      "loss: 2.301181  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 12.5%, Avg loss: 2.304109 \n",
      "\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 2.298593  [   64/60000]\n",
      "loss: 2.298250  [ 6464/60000]\n",
      "loss: 2.303465  [12864/60000]\n",
      "loss: 2.308721  [19264/60000]\n",
      "loss: 2.295448  [25664/60000]\n",
      "loss: 2.311263  [32064/60000]\n",
      "loss: 2.299534  [38464/60000]\n",
      "loss: 2.303914  [44864/60000]\n",
      "loss: 2.295982  [51264/60000]\n",
      "loss: 2.301181  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 12.5%, Avg loss: 2.304109 \n",
      "\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 2.298593  [   64/60000]\n",
      "loss: 2.298250  [ 6464/60000]\n",
      "loss: 2.303465  [12864/60000]\n",
      "loss: 2.308721  [19264/60000]\n",
      "loss: 2.295448  [25664/60000]\n",
      "loss: 2.311263  [32064/60000]\n",
      "loss: 2.299534  [38464/60000]\n",
      "loss: 2.303914  [44864/60000]\n",
      "loss: 2.295982  [51264/60000]\n",
      "loss: 2.301181  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 12.5%, Avg loss: 2.304109 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "# Fail fast on stale notebook state: an optimizer built from an earlier `model` object\n",
    "# keeps updating that old object's parameters, so the current model never learns and\n",
    "# every epoch reports the same losses.\n",
    "model_param_ids = {id(p) for p in model.parameters()}\n",
    "optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group[\"params\"]}\n",
    "if not optimizer_param_ids <= model_param_ids:\n",
    "    raise RuntimeError(\n",
    "        \"optimizer is not bound to the current model's parameters; \"\n",
    "        \"re-run the optimizer cell before training\"\n",
    "    )\n",
    "\n",
    "# Train for a fixed number of epochs, evaluating on the test split after each one.\n",
    "epochs = 5\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train(train_dataloader, model, loss_fn, optimizer)\n",
    "    test(test_dataloader, model, loss_fn)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved PyTorch Model State to model.pth\n"
     ]
    }
   ],
   "source": [
    "# Saving Models\n",
    "\n",
    "# Persist only the model's state_dict (its parameters and buffers) to model.pth.\n",
    "torch.save(obj=model.state_dict(), f=\"model.pth\")\n",
    "print(\"Saved PyTorch Model State to model.pth\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading Models\n",
    "\n",
    "# Rebuild the architecture, then restore the saved parameters into it.\n",
    "# weights_only=True restricts unpickling to tensors and plain containers, so a tampered\n",
    "# checkpoint file cannot execute arbitrary code. map_location places the loaded tensors\n",
    "# on the current `device` even if the file was written on a different one.\n",
    "model = NeuralNetwork().to(device)\n",
    "model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "metadata": {}
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted: \"Sneaker\", Actual: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "# Prediction\n",
    "\n",
    "# Class names, in the order indexed by the integer labels used below.\n",
    "classes = [\n",
    "    \"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
    "    \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\",\n",
    "]\n",
    "\n",
    "model.eval()\n",
    "# Classify the first test sample and compare the result with its label.\n",
    "x, y = test_data[0]\n",
    "x = x.to(device)\n",
    "with torch.no_grad():\n",
    "    pred = model(x)\n",
    "predicted = classes[pred[0].argmax(0)]\n",
    "actual = classes[y]\n",
    "print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
